import os
import os.path as osp
import pickle

import gym
import numpy as np
from omegaconf import OmegaConf

import diffgro
from diffgro.utils.config import load_config
from diffgro.utils import make_dir, print_y, print_b
from train import *


def simulate_traj(policy, env: "gym.Env"):
    """Roll out ``policy`` in ``env`` for one episode and return the trajectory.

    The environment is reset first; the policy's ``reset()`` is then called
    when it defines one. Each step queries
    ``policy.predict(obs, deterministic=True)`` and passes the first returned
    value to ``env.step`` until the environment reports ``done``.

    Args:
        policy: Object exposing ``predict(obs, deterministic=...)`` returning
            a 3-tuple whose first element is the action. It may optionally
            expose ``reset()``.
        env: Environment whose ``step`` returns ``(obs, reward, done, info)``.
            The ``info`` of the final step must contain a ``"success"`` entry.

    Returns:
        ``{"actions": ndarray, "observations": ndarray}`` when the final
        step's ``info["success"]`` is truthy, otherwise ``None``.
        ``observations[i]`` is the observation returned *after* applying
        ``actions[i]``; the observation from ``env.reset()`` is not stored.

    Raises:
        KeyError: If the final ``info`` has no ``"success"`` entry.
    """
    obs, done = env.reset(), False
    if getattr(policy, "reset", None) is not None:
        policy.reset()

    actions, observations = [], []
    info = {}
    while not done:
        action, _, _ = policy.predict(obs, deterministic=True)
        # np.array copies by default, so the stored action is detached from
        # whatever buffer the policy handed back.
        action = np.array(action)
        obs, _, done, info = env.step(action)

        observations.append(obs)
        actions.append(action)

    # Only the outcome reported on the last step decides whether to keep it.
    if not info["success"]:
        return None
    return {
        "actions": np.array(actions),
        "observations": np.array(observations),
    }


def _collect_for_guide(model, env, task, guide_pt):
    """Collect and pickle rollouts for one (task, guide) pair.

    For a named guide, ``delta`` starts at 0 and grows by a fixed step after
    every successful rollout; the sweep ends at the first failed rollout
    (``simulate_traj`` returns ``None``). Only the ``"faster"`` guide has an
    upper bound on ``delta``, so any other named guide keeps sweeping until a
    rollout fails. For ``guide_pt=None`` the model is configured with
    ``delta=0`` / ``"faster"`` and a fixed batch of episodes is attempted
    once, stopping early at the first failure.

    Successful trajectories are written to
    ``visualizations/metaworld/<task>/<guide or "default">/<delta>_<episode>.pkl``.
    """
    delta_step = 0.005
    faster_delta_cap = 0.40
    default_episodes = 40

    save_path = osp.join(
        "visualizations",
        "metaworld",
        task,
        guide_pt if guide_pt is not None else "default",
    )
    make_dir(save_path)

    delta = 0
    while True:
        # Check the cap before touching the model or logging, so the last
        # message printed never describes a sweep step that does not run.
        if guide_pt == "faster" and delta > faster_delta_cap:
            return

        if guide_pt is None:
            model.delta = 0
            model.guide_pt = "faster"
        else:
            model.delta = delta
            model.guide_pt = guide_pt
        model._setup_guide()
        print_y(
            f"Collecting dataset for {task} with guide {guide_pt} and delta {delta:.4f}"
        )

        num_episodes = 1 if guide_pt is not None else default_episodes
        for episode in range(num_episodes):
            traj = simulate_traj(model, env)
            if traj is None:
                # A failed rollout ends the sweep for this guide.
                return

            with open(
                osp.join(save_path, f"{delta:.4f}_{episode}.pkl"), "wb"
            ) as f:
                pickle.dump(traj, f)

        if guide_pt is None:
            # The unguided setting has no delta to sweep.
            return
        # Accumulated in floating point; file names round it to 4 decimals.
        delta += delta_step


def collect_dataset(args):
    """Collect successful DiffGro rollouts for every task in ``args.task_list``.

    Loads the planner from ``<args.model_path>/planner`` and the algorithm
    config for ``args.domain_name``, then for each task builds an environment
    through ``make_env`` and a ``diffgro.DiffGro`` model, and runs every
    guide setting through ``_collect_for_guide``.

    Args:
        args: Config object providing ``model_path``, ``domain_name`` and
            ``task_list``. ``args.env_name`` is overwritten with
            ``"<domain_name>.<task>"`` before each ``make_env`` call.

    Side effects:
        Creates directories and writes pickle files under ``visualizations/``.
    """
    print_y("Loading ... DiffGroPlanner ...")
    planner = diffgro.DiffGroPlanner.load(osp.join(args.model_path, "planner"))
    config = load_config("./config/algos/diffgro.yml", args.domain_name)
    # Does not depend on the task, so read it once.
    history = config["planner"]["inference"]["history"]

    guide_sweep = (
        "x_faster",
        "x_slower",
        "y_faster",
        "y_slower",
        "slower",
        "faster",
        None,
    )

    for task in args.task_list:
        args.env_name = f"{args.domain_name}.{task}"
        env, _, _ = make_env(args)
        try:
            model = diffgro.DiffGro(
                env,
                planner,
                history=history,
                delta=0,
                guide="test",
                guide_pt="faster",
            )
            for guide_pt in guide_sweep:
                _collect_for_guide(model, env, task, guide_pt)
        finally:
            # Release the per-task environment even if collection fails,
            # instead of leaking one environment per task.
            close_env = getattr(env, "close", None)
            if close_env is not None:
                close_env()

if __name__ == "__main__":
    # Tasks swept by this run; the commented-out entries are currently skipped.
    task_list = [
        # "drawer-close-variant-v2",
        # "button-press-variant-v2",
        "drawer-open-variant-v2",
        "door-open-variant-v2",
        "faucet-open-variant-v2",
        "window-open-variant-v2",
        "window-close-variant-v2",
        "push-variant-v2",
        "pick-place-variant-v2",
        "peg-insert-side-variant-v2",
    ]
    # NOTE(review): goal_resistance and seed are not read in this file;
    # presumably make_env consumes them - confirm.
    args = OmegaConf.create(
        dict(
            domain_name="metaworld",
            env_name=None,  # filled in per task by collect_dataset
            goal_resistance=0,
            seed=777,
            model_path="models/diffgro",
            task_list=task_list,
        )
    )
    collect_dataset(args)
